"""
charlm.jl: AutoGrad version (c) Emre Yolcu, 2016

This example implements an LSTM network for training character-level language
models. It takes as input a text file and trains the network to predict the
next character in a sequence. It can then be used to generate a sample that
resembles the original text.

To train the network, do `include("charlm.jl")` and run
`CharLM.train()`. Initial parameters can be provided as an optional argument to
`train`. Required form of the parameters can be understood by looking at the
definition of `weights` below. If not provided, default parameters are created
using `weights` with the appropriate arguments passed from `train`.

`train` accepts the following keyword arguments:

  - `datasrc`: Path to a file, or a URL. "The Complete Works of Shakespeare"
    from Project Gutenberg will be used by default.
  - `char_limit`: Maximum number of characters to read from `datasrc`
  - `epochs`: Number of epochs
  - `lr_init`: Initial learning rate
  - `lr_decay`: Learning rate decay
  - `decay_after`: Epoch after which to start decaying learning rate
  - `embedding_size`: Size of the embedding vector
  - `hidden_size`: Size of the LSTM internal state
  - `batch_size`: Number of sequences to train on in parallel
  - `sequence_length`: Number of steps to unroll the network for
  - `gclip`: Value to clip the gradient norm at
  - `pdrop`: Dropout probability
  - `seed`: Seed for the random number generator

At the end of each epoch, the cross entropy loss for the training data, the
learning rate, and the time taken by the epoch to complete are
printed. Optimized parameters and an array that maps integer indices to their
characters are returned at the end of training. Then, a sample text can be
generated by running `CharLM.generate(w, rv, n)` where `w` is the dictionary of
parameters, `rv` is the mapping of indices to characters, and `n` is the number
of characters to be generated.
"""

module CharLM

using AutoGrad
using Base.LinAlg: axpy!

# Elementwise logistic sigmoid. Defined by hand (rather than composed inside the
# model) so that a custom gradient can be registered with AutoGrad below.
sigm(x) = 1 ./ (1 + exp.(-x))
# Register the analytic gradient with AutoGrad: d/dx sigm(x) = y .* (1 - y),
# where `y` is the forward output and `dy` the incoming gradient. This avoids
# differentiating through exp/division element by element.
@primitive sigm(x),dy,y  (dy .* y .* (1 - y))
# Elementwise tanh; named `tanx` to avoid shadowing/confusion with Base.tanh.
tanx(x) = tanh.(x)

# Xavier/Glorot uniform initialization: a `fan_out` x `fan_in` Float32 matrix
# with entries drawn uniformly from [-scale, scale), scale = sqrt(6 / (fan_in + fan_out)).
function xavier(fan_out, fan_in)
    scale = sqrt(6 / (fan_in + fan_out))
    # Use broadcast subtraction (.-): plain `Array - scalar` was deprecated and
    # later removed from Julia, while `.-` is valid (and identical) everywhere.
    return convert(Array{Float32}, 2 * scale * rand(fan_out, fan_in) .- scale)
end

# Build the model's parameter dictionary: one weight matrix and bias vector per
# LSTM gate, plus the input embedding and the output projection. Weight
# matrices are drawn with `init`; biases start at zero, except the forget gate.
function weights(; input_size=0, output_size=0, embedding_size=0,
                 hidden_size=0, batch_size=0, init=xavier)
    w = Dict()
    for gate in (:ingate, :forget, :outgate, :change)
        w[Symbol(:W_, gate)] = init(hidden_size, embedding_size + hidden_size)
        # Forget-gate biases start at one (gate initially open), which is the
        # usual trick to help gradients flow early in training.
        if gate == :forget
            w[Symbol(:b_, gate)] = ones(Float32, (hidden_size, 1))
        else
            w[Symbol(:b_, gate)] = zeros(Float32, (hidden_size, 1))
        end
    end
    w[:W_embedding] = init(embedding_size, input_size)
    w[:W_predict]   = init(output_size, hidden_size)
    w[:b_predict]   = zeros(Float32, (output_size, 1))
    return w
end

# One LSTM time step: gate the previous cell with the (embedded) input and
# previous hidden state, returning the new (hidden, cell) pair.
function lstm(w, input, hidden, cell)
    # A single concatenation shared by all four gates (cheaper than four vcats).
    combined = vcat(input, hidden)
    preact(g) = w[Symbol(:W_, g)] * combined .+ w[Symbol(:b_, g)]
    ingate  = sigm(preact(:ingate))
    forget  = sigm(preact(:forget))
    outgate = sigm(preact(:outgate))
    change  = tanx(preact(:change))
    newcell   = cell .* forget + ingate .* change
    newhidden = outgate .* tanx(newcell)
    return newhidden, newcell
end

# Project a hidden state to log-probabilities over the vocabulary: a linear
# layer followed by a column-wise log-softmax.
function predict(w, hidden)
    output = w[:W_predict] * hidden .+ w[:b_predict]
    # log-softmax: subtract the log of the per-column sums of exp (dimension 1).
    return output .- log(sum(exp.(output), 1))
end

# Inverted dropout: zero each element of `x` with probability `pdrop`, and
# scale the survivors by 1/(1 - pdrop) so the expected activation is unchanged
# (no rescaling needed at test time).
function dropout(x, pdrop)
    # rand(size(x)...) rather than rand(size(x)): passing the dims tuple itself
    # is ambiguous across Julia versions (newer Julia treats a tuple argument as
    # a collection to sample from); splatting the dims is identical and safe.
    return x .* (rand(size(x)...) .< (1 - pdrop)) / (1 - pdrop)
end

# Run the LSTM over a sequence of one-hot input batches, returning the list of
# per-step log-probability outputs and the final (hidden, cell) state. Starts
# from zero state unless `state` is supplied; applies dropout to the hidden
# state before prediction when `pdrop > 0`.
function forw(w, inputs; state=nothing, pdrop=0)
    ncols = size(inputs[1], 2)           # batch size
    nhid  = size(w[:W_predict], 2)       # hidden size, from the output projection
    if state === nothing
        hidden = zeros(Float32, (nhid, ncols))
        cell   = zeros(Float32, (nhid, ncols))
    else
        hidden, cell = state
    end
    outputs = Any[]
    for x in inputs
        hidden, cell = lstm(w, w[:W_embedding] * x, hidden, cell)
        if pdrop > 0
            hidden = dropout(hidden, pdrop)
        end
        push!(outputs, predict(w, hidden))
    end
    return outputs, (hidden, cell)
end

# Per-sequence cross-entropy loss: since `targets` are one-hot and `forw`
# returns log-probabilities, the dot product picks out the log-probability of
# each correct character. Normalized by batch size (columns), not by steps.
function loss(w, inputs, targets; kwargs...)
    outputs = forw(w, inputs; kwargs...)[1]
    total = 0.0
    for (out, tgt) in zip(outputs, targets)
        total += sum(out .* tgt)
    end
    return -total / size(inputs[1], 2)
end

# Sum of the norms of all gradient arrays in `g`. Note this is a *sum* of
# per-array norms (an upper bound on the true global norm), used only as the
# clipping threshold in `train`.
function gnorm(g)
    return mapreduce(vecnorm, +, 0, values(g))
end

# Train the character-level LSTM with plain SGD and gradient clipping.
# If `w` is not supplied, fresh parameters are created from the data's
# vocabulary. Returns the trained parameter dictionary and the
# index-to-character array needed by `generate`. See the module docstring for
# the keyword arguments.
function train(w=nothing; datasrc=nothing, char_limit=0, epochs=1, lr_init=1.0,
               lr_decay=0.95, decay_after=10, embedding_size=128,
               hidden_size=256, batch_size=50, sequence_length=50, gclip=5.0,
               pdrop=0, seed=0)
    seed > -1 && srand(seed)  # seed the global RNG for reproducibility
    data, chars, char_to_index, index_to_char = loaddata(datasrc, batch_size, char_limit)
    vocab_size = length(char_to_index)
    if w == nothing
        w = weights(; input_size=vocab_size, output_size=vocab_size,
                    embedding_size=embedding_size, hidden_size=hidden_size,
                    batch_size=batch_size)
    end
    # AutoGrad: gradfun has the same signature as `loss` but returns dloss/dw.
    gradfun = grad(loss)

    for epoch = 1:epochs
        start_time = time()
        targets = Any[]
        inputs = Any[]
        loss_count = zeros(2)  # [accumulated loss, accumulated step count] for reporting
        # Exponential learning-rate decay, starting after `decay_after` epochs.
        lr = lr_init * lr_decay^max(0, epoch - decay_after)
        T = length(data) - 1
        for t = 1:T
            # Input is the batch at step t; target is the following character's batch.
            push!(inputs, copy(data[t])) # why copy here? there is no overwriting in AutoGrad.
            push!(targets, copy(data[t + 1]))
            # Truncated BPTT: take one SGD step per `sequence_length` steps
            # (the LSTM state is re-zeroed for each subsequence, since `loss`
            # calls `forw` without passing a carried-over state).
            if (t % sequence_length == 0) || t == T
                loss_count[1] += loss(w, inputs, targets; pdrop=pdrop)
                loss_count[2] += length(inputs)
                g = gradfun(w, inputs, targets; pdrop=pdrop)
                gn = (gclip > 0 ? gnorm(g) : 0)
                # Scale gradients down uniformly when their norm exceeds gclip.
                gscale = (gn > gclip > 0 ? (gclip / gn) : 1)
                for p in keys(w)
                    # In-place SGD update: w[p] += -lr * gscale * g[p].
                    axpy!(-lr * gscale, g[p], w[p])
                end
                empty!(inputs)
                empty!(targets)
            end
            # Periodic progress report on stderr.
            if t % 1000 == 0
                elapsed_time = time() - start_time
                @printf(STDERR, "Epoch: %d, t: %d/%d, Loss: %.6f, LR: %.6f, Time: %.6f\n",
                        epoch, t, T, loss_count[1] / loss_count[2], lr, elapsed_time)
            end
        end
    end

    return w, index_to_char
end

# Read the text at `datasrc` (downloading "The Complete Works of Shakespeare"
# if no source is given and the default file is missing), build the character
# vocabulary, and minibatch the text. Returns (one-hot minibatches, raw
# character vector, char=>index dict, index=>char array).
function loaddata(datasrc, batch_size, char_limit=0)
    if datasrc == nothing
        datasrc = joinpath(AutoGrad.datapath, "pg100.txt")
    end
    if !isfile(datasrc)
        url = "http://www.gutenberg.org/cache/epub/100/pg100.txt"
        download(url,datasrc)
    end
    stream = open(datasrc)
    chars = Char[]
    char_to_index = Dict{Char, Int32}()
    while !eof(stream)
        c = read(stream, Char)
        # Assign the next free index to characters seen for the first time.
        get!(char_to_index, c, 1 + length(char_to_index))
        push!(chars, c)
        char_limit > 0 && length(chars) >= char_limit && break
    end
    close(stream)  # fix: the file handle was previously never released
    info("Read: $(length(chars)) characters, $(length(char_to_index)) vocabulary")
    data = minibatch(chars, char_to_index, batch_size)
    # Invert char_to_index into a dense index=>char lookup array.
    index_to_char = Array{Char}(length(char_to_index))
    for (c, i) in char_to_index
        index_to_char[i] = c
    end
    return data, chars, char_to_index, index_to_char
end

# Split `chars` into `batch_size` contiguous streams of `nbatch` characters
# each, and build one vocab x batch_size one-hot matrix per time step: column
# j of batch i is the one-hot encoding of the i-th character of stream j.
# Trailing characters that do not fill a step are dropped.
function minibatch(chars, char_to_index, batch_size)
    nbatch = div(length(chars), batch_size)
    vocab = length(char_to_index)
    data = Any[]
    for step = 1:nbatch
        onehot = zeros(Float32, (vocab, batch_size))
        for col = 1:batch_size
            row = char_to_index[chars[step + nbatch * (col - 1)]]
            onehot[row, col] = 1
        end
        push!(data, onehot)
    end
    return data
end

# Generate and print `nchar` characters from the trained model `w` by feeding
# each sampled character back in as the next input (the very first input is
# the all-zeros vector, matching the original seeding behavior).
function generate(w, index_to_char, nchar)
    vocab_size = length(index_to_char)
    x = zeros(Float32, (vocab_size, 1))
    s = nothing  # LSTM (hidden, cell) state, carried across steps
    j = 1
    for i = 1:nchar
        y, s = forw(w, Any[x]; state=s)
        x[j, 1] = 0  # clear the previous one-hot entry
        # `predict` returns log-probabilities; exponentiate elementwise before
        # sampling. Fix: exp.(...) instead of exp(...) — applying exp to a
        # non-square matrix is not elementwise (and the file uses exp. elsewhere).
        j = sample(exp.(y[1]))
        x[j, 1] = 1  # one-hot encode the sampled character for the next step
        print(index_to_char[j])
    end
    println()
end

# Draw an index c with probability p[c] via inverse-transform sampling, where
# `p` is a (flattened-indexable) collection of probabilities summing to ~1.
function sample(p)
    r = rand(Float32)
    for c = 1:length(p)
        r -= p[c]
        r < 0 && return c
    end
    # Fix: if floating-point rounding leaves r >= 0 after consuming all the
    # mass, fall back to the last index instead of implicitly returning
    # `nothing` (which would crash the caller's indexing).
    return length(p)
end

end  # module
